// Copyright by LunaSec (owned by Refinery Labs, Inc)
//
// Licensed under the Business Source License v1.1
// (the "License"); you may not use this file except in compliance with the
// License. You may obtain a copy of the License at
//
// https://github.com/lunasec-io/lunasec/blob/master/licenses/BSL-LunaTrace.txt
//
// See the License for the specific language governing permissions and
// limitations under the License.
package advisory

import (
	"encoding/json"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"time"

	"github.com/blang/semver/v4"
	"github.com/facebookincubator/nvdtools/cvss3"
	"github.com/rs/zerolog/log"
	"golang.org/x/exp/slices"

	"github.com/samber/lo"

	"github.com/lunasec-io/lunasec/lunatrace/bsl/ingest-worker/pkg/vulnerability/schema"
	"github.com/lunasec-io/lunasec/lunatrace/cli/pkg/util"
	"github.com/lunasec-io/lunasec/lunatrace/gogen/gql"
	"github.com/lunasec-io/lunasec/lunatrace/gogen/gql/types"
)

const (
	// OSVDatabaseSpecific is the key inside an OSV record's database_specific
	// object that holds the list of CWE ids (read by mapCwes).
	OSVDatabaseSpecific = "cwe_ids"
	// CVSS3None is a CVSS v3 base vector whose confidentiality, integrity and
	// availability impacts are all None (C:N/I:N/A:N).
	CVSS3None           = "AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:N"
)

// isReviewed reports whether the advisory's database_specific object carries a
// boolean "github_reviewed" flag set to true. Any other value, including a
// missing key or a non-bool, counts as not reviewed.
func isReviewed(rawVuln *schema.OsvSchema) bool {
	// TODO: For now this only works for github
	flag, isBool := rawVuln.DatabaseSpecific["github_reviewed"].(bool)
	return isBool && flag
}

// getSeverityName maps the advisory's database_specific "severity" label onto
// one of ('Critical','High','Medium','Low','Negligible','Unknown').
//
// GitHub labels are upper case and use MODERATE for the medium band. There is
// no negligible level in GitHub data, so nothing maps to 'Negligible'. A
// missing, non-string, or unrecognised label yields "Unknown".
func getSeverityName(rawVuln *schema.OsvSchema) string {
	// TODO: For now this only works for github
	githubToName := map[string]string{
		"CRITICAL": "Critical",
		"HIGH":     "High",
		"MODERATE": "Medium",
		"LOW":      "Low",
	}

	label, _ := rawVuln.DatabaseSpecific["severity"].(string)
	if name, found := githubToName[label]; found {
		return name
	}
	return "Unknown"
}

// MapOsvToGraphql converts an OSV formatted vulnerability to the Graphql representation of a vulnerability.
// Example OSV: https://github.com/github/advisory-database/blob/main/advisories/github-reviewed/2021/12/GHSA-jfh8-c2jp-5v3q/GHSA-jfh8-c2jp-5v3q.json
//
// source is recorded as the vulnerability's source (and on its severities).
//
// An error is returned in these cases:
//   - rawVuln is nil.
//   - The record or its database_specific data cannot be serialized.
//   - Its modified timestamp is not RFC 3339.
//   - A mapping helper rejects the data.
//
// Several helpers panic on unexpected input: parseOptionalTime,
// determineSourceFromId, mapReferenceType and getEventValue. Those panics are
// recovered here and returned as an error, so one malformed advisory does not
// take down the caller. On error the returned insert is nil.
func MapOsvToGraphql(source string, rawVuln *schema.OsvSchema) (graphqlInsert *gql.Vulnerability_insert_input, err error) {
	if rawVuln == nil {
		return nil, fmt.Errorf("osv vulnerability is nil")
	}

	defer func() {
		if r := recover(); r != nil {
			log.Error().
				Str("id", rawVuln.Id).
				Interface("panic", r).
				Msg("unable to map osv vulnerability")
			graphqlInsert = nil
			err = fmt.Errorf("unable to map osv vulnerability %s: %v", rawVuln.Id, r)
		}
	}()

	osvSchema, err := serializeToJsonRawMessage(rawVuln)
	if err != nil {
		log.Error().
			Err(err).
			Str("id", rawVuln.Id).
			Msg("unable to serialize osv schema")
		return nil, err
	}

	databaseSpecific, err := serializeToJsonRawMessage(rawVuln.DatabaseSpecific)
	if err != nil {
		log.Error().
			Err(err).
			Str("id", rawVuln.Id).
			Msg("unable to serialize database specific info")
		return nil, err
	}

	modifiedTime, err := time.Parse(time.RFC3339, rawVuln.Modified)
	if err != nil {
		log.Error().
			Err(err).
			Str("id", rawVuln.Id).
			Msg("unable to parse modified time")
		return nil, err
	}

	cvssScore := mapCvssScore(rawVuln)
	severityName := getSeverityName(rawVuln)
	reviewed := isReviewed(rawVuln)
	now := time.Now()
	graphqlInsert = &gql.Vulnerability_insert_input{
		Modified:           util.Ptr(modifiedTime),
		Published:          parseOptionalTime(rawVuln.Published),
		Withdrawn:          parseOptionalTime(rawVuln.Withdrawn),
		Affected:           mapAffected(rawVuln.Affected),
		Credits:            mapCredits(rawVuln.Credits),
		Database_specific:  &databaseSpecific,
		Details:            rawVuln.Details,
		Equivalents:        mapEquivalents(rawVuln.Aliases),
		References:         mapReferences(rawVuln.References),
		Severities:         mapSeverities(source, rawVuln.Severity),
		Source:             util.Ptr(source),
		Source_id:          util.Ptr(rawVuln.Id),
		Summary:            rawVuln.Summary,
		Cvss_score:         cvssScore,
		Upstream_data:      util.Ptr(osvSchema),
		Reviewed_by_source: &reviewed,
		Severity_name:      &severityName,
		Last_fetched:       &now,
		Cwes:               mapCwes(rawVuln.DatabaseSpecific),
		Cve_id:             getCveId(rawVuln),
	}

	return graphqlInsert, nil
}

// mapCwes builds the vulnerability_cwe relationship insert from the "cwe_ids"
// entry (OSVDatabaseSpecific) of an OSV database_specific object, for example
// ["CWE-79", "CWE-89"]. Each id becomes a nested cwe row with an empty
// description, using CweOnConflict as its conflict clause.
//
// It returns nil when the entry is present but empty or malformed. Malformed
// means: not a list, a non-string element, or an id whose numeric part does not
// parse. When the entry is absent, it returns a relationship input with no data.
func mapCwes(databaseSpecific schema.OsvSchemaDatabaseSpecific) *gql.Vulnerability_vulnerability_cwe_arr_rel_insert_input {
	var cwes []*gql.Vulnerability_vulnerability_cwe_insert_input

	if value, present := databaseSpecific[OSVDatabaseSpecific]; present {
		cweIdInterface, ok := value.([]interface{})
		if !ok {
			log.Warn().
				Interface("value", value).
				// fmt.Sprint renders a nil reflect.Type as "<nil>". Calling
				// .String() on it directly panics when the JSON value is null.
				Str("type", fmt.Sprint(reflect.TypeOf(value))).
				Msg("unable to type cast value as interface array of cwe ids")
			return nil
		}

		if len(cweIdInterface) == 0 {
			return nil
		}

		for _, cwe := range cweIdInterface {
			cweId, cweOk := cwe.(string)
			if !cweOk {
				log.Warn().
					Interface("value", cwe).
					Str("type", fmt.Sprint(reflect.TypeOf(cwe))).
					Msg("unable to type cast value as cwe id string")
				return nil
			}

			// Strip the leading "CWE-" so the remainder can be parsed as an int.
			cweIdNum, err := strconv.Atoi(strings.TrimPrefix(cweId, "CWE-"))
			if err != nil {
				log.Warn().
					Interface("value", cwe).
					Msg("unable to parse cwe id to number")
				return nil
			}

			name := fmt.Sprintf("CWE-%d", cweIdNum)

			cwes = append(cwes, &gql.Vulnerability_vulnerability_cwe_insert_input{
				Cwe: &gql.Vulnerability_cwe_obj_rel_insert_input{
					Data: &gql.Vulnerability_cwe_insert_input{
						Id:                   util.Ptr(cweIdNum),
						Name:                 util.Ptr(name),
						Description:          util.Ptr(""),
						Extended_description: util.Ptr(""),
					},
					On_conflict: CweOnConflict,
				},
			})
		}
	}

	return &gql.Vulnerability_vulnerability_cwe_arr_rel_insert_input{
		Data:        cwes,
		On_conflict: VulnerabilityCweOnConflict,
	}
}

// cvssScoreSeverityRating maps a CVSS v3 base score to the qualitative
// severity rating defined by the CVSS v3.1 specification:
//
//	none      0.0
//	low       0.1 - 3.9
//	medium    4.0 - 6.9
//	high      7.0 - 8.9
//	critical  9.0 - 10.0
//
// The lower bound of every band is inclusive, so a score that sits exactly on
// a boundary (e.g. 9.0) falls into the higher band.
func cvssScoreSeverityRating(score float64) string {
	switch {
	case score >= 9.0:
		return "critical"
	case score >= 7.0:
		return "high"
	case score >= 4.0:
		return "medium"
	case score >= 0.1:
		return "low"
	default:
		return "none"
	}
}

// mapCvssScore computes the numeric CVSS base score for vuln.
//
// It prefers the first CVSS_V3 severity entry and falls back to the first
// severity entry of any type. It returns nil in these cases:
//   - there are no severities;
//   - the chosen entry has an empty score;
//   - the vector cannot be parsed by cvss3.VectorFromString (logged as a warning).
func mapCvssScore(vuln *schema.OsvSchema) *float64 {
	if len(vuln.Severity) == 0 {
		return nil
	}

	// Index into the slice instead of taking &severity of the range variable.
	// Before Go 1.22 that variable is reused across iterations, so its address
	// ends up referring to the last element rather than the one that matched.
	cvss3Severity := &vuln.Severity[0]
	for i := range vuln.Severity {
		if vuln.Severity[i].Type == schema.OsvSchemaSeverityElemTypeCVSSV3 {
			cvss3Severity = &vuln.Severity[i]
			break
		}
	}

	if cvss3Severity.Score == "" {
		return nil
	}

	cvssVec, err := cvss3.VectorFromString(cvss3Severity.Score)
	if err != nil {
		log.Warn().
			Err(err).
			Str("id", vuln.Id).
			Msg("unable to parse cvss severity")
		return nil
	}
	score := cvssVec.Score()
	return &score
}

// stringSlicePostgresqlArray renders slice as a PostgreSQL array literal, for
// example {"a","b"}, suitable for a text[] column. An empty or nil slice
// renders as {}.
//
// Every element is double-quoted, and embedded backslashes and double quotes
// are backslash-escaped. Without quoting, Postgres would mis-parse values that
// contain commas, braces, quotes, backslashes or whitespace, the bare word NULL,
// or the empty string: it would split them, reject them, or reinterpret them.
func stringSlicePostgresqlArray(slice []string) string {
	var b strings.Builder
	b.WriteByte('{')
	for i, element := range slice {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteByte('"')
		// Byte-wise scan is UTF-8 safe: '\\' and '"' are ASCII and never occur
		// inside a multi-byte sequence.
		for j := 0; j < len(element); j++ {
			c := element[j]
			if c == '\\' || c == '"' {
				b.WriteByte('\\')
			}
			b.WriteByte(c)
		}
		b.WriteByte('"')
	}
	b.WriteByte('}')
	return b.String()
}

// serializeToJsonRawMessage JSON-encodes i and returns the bytes as a
// json.RawMessage. When encoding fails the message is nil and the marshal
// error is returned unchanged.
func serializeToJsonRawMessage(i interface{}) (json.RawMessage, error) {
	encoded, marshalErr := json.Marshal(i)
	if marshalErr == nil {
		return json.RawMessage(encoded), nil
	}
	return nil, marshalErr
}

// getCveId returns the first alias of vuln that is a CVE identifier, or nil if
// there is none. An alias counts as a CVE when it has the "CVE-" prefix, the
// same rule determineSourceFromId uses to classify an alias as an NVD id.
func getCveId(vuln *schema.OsvSchema) *string {
	cveId, found := lo.Find(vuln.Aliases, func(alias string) bool {
		return strings.HasPrefix(alias, "CVE-")
	})
	if !found {
		return nil
	}
	return &cveId
}

// dedupeAffectedUsingPackageName merges OSV affected entries that refer to the
// same package. OSV can split one package across several entries in the
// affected array; merging them makes querying for affected versions more
// straightforward.
//
// Details:
//   - Entries are merged when both ecosystem and name are equal. The Versions
//     and Ranges of later entries are appended to the first entry for that
//     package.
//   - Entries without a package are dropped.
//   - The result keeps the order in which each package first appears, so the
//     same advisory always maps to the same insert ordering.
//   - Merged Versions/Ranges slices are freshly allocated and never write into
//     the input's backing arrays.
func dedupeAffectedUsingPackageName(affected []schema.OsvSchemaAffectedElem) []schema.OsvSchemaAffectedElem {
	// A struct key avoids the ambiguity of concatenating ecosystem and name.
	type packageKey struct {
		ecosystem string
		name      string
	}

	indexByPackage := make(map[packageKey]int, len(affected))
	merged := make([]schema.OsvSchemaAffectedElem, 0, len(affected))

	for _, affectedElem := range affected {
		affectedPackage := affectedElem.Package

		// TODO (cthompson) what do we want to do if no package is provided? should we just be ignoring it?
		if affectedPackage == nil {
			continue
		}

		key := packageKey{ecosystem: affectedPackage.Ecosystem, name: affectedPackage.Name}
		if idx, ok := indexByPackage[key]; ok {
			existing := &merged[idx]
			// Cap each slice at its length so append always copies instead of
			// writing into spare capacity shared with the input.
			existing.Versions = append(existing.Versions[:len(existing.Versions):len(existing.Versions)], affectedElem.Versions...)
			existing.Ranges = append(existing.Ranges[:len(existing.Ranges):len(existing.Ranges)], affectedElem.Ranges...)
			continue
		}

		indexByPackage[key] = len(merged)
		merged = append(merged, affectedElem)
	}
	return merged
}

// mapReferences converts the OSV references list into the vulnerability
// reference relationship insert. It returns nil when there are no references.
func mapReferences(references []schema.OsvSchemaReferencesElem) *gql.Vulnerability_reference_arr_rel_insert_input {
	if len(references) == 0 {
		return nil
	}

	rows := make([]*gql.Vulnerability_reference_insert_input, len(references))
	for idx := range references {
		rows[idx] = mapReferenceElement(references[idx])
	}

	return &gql.Vulnerability_reference_arr_rel_insert_input{
		Data:        rows,
		On_conflict: VulnerabilityReferenceOnConflict,
	}
}

// mapReferenceElement converts one OSV reference into a reference row. The
// type is normalized via mapReferenceType, which panics on unknown types.
func mapReferenceElement(reference schema.OsvSchemaReferencesElem) *gql.Vulnerability_reference_insert_input {
	normalizedType := mapReferenceType(reference.Type)
	link := reference.Url

	row := &gql.Vulnerability_reference_insert_input{}
	row.Type = util.Ptr(normalizedType)
	row.Url = util.Ptr(link)
	return row
}

// mapSeverities converts the OSV severity list into the vulnerability severity
// relationship insert, tagging each entry with source. It returns nil when
// there are no severities.
func mapSeverities(source string, severities []schema.OsvSchemaSeverityElem) *gql.Vulnerability_severity_arr_rel_insert_input {
	if len(severities) == 0 {
		return nil
	}

	rows := make([]*gql.Vulnerability_severity_insert_input, len(severities))
	for idx, severity := range severities {
		rows[idx] = mapSeverityElement(source, severity)
	}

	return &gql.Vulnerability_severity_arr_rel_insert_input{
		Data:        rows,
		On_conflict: VulnerabilitySeverityOnConflict,
	}
}

// mapSeverityElement converts one OSV severity entry into a severity row,
// copying its score and type verbatim and attributing it to source.
func mapSeverityElement(source string, severity schema.OsvSchemaSeverityElem) *gql.Vulnerability_severity_insert_input {
	// TODO (cthompson) it looks like for severity only CVSS_V3 is defined https://github.com/ossf/osv-schema/blob/main/validation/schema.json#L43
	// do we want to make the type column an enum?
	severityType := string(severity.Type)
	score := severity.Score

	return &gql.Vulnerability_severity_insert_input{
		Source: util.Ptr(source),
		Score:  util.Ptr(score),
		Type:   util.Ptr(severityType),
	}
}

// mapEquivalents builds the equivalent-vulnerability relationship insert for an
// advisory's aliases. Each alias becomes a stub vulnerability keyed by its
// source and source_id (see mapEquivalentElement).
//
// Aliases whose source cannot be determined are skipped with a warning rather
// than aborting the whole mapping. Only CVE- and GHSA- ids are recognised, so
// this covers e.g. PYSEC-, GO- and OSV- ids. It returns nil when no alias can
// be mapped.
func mapEquivalents(aliases []string) *gql.Vulnerability_equivalent_arr_rel_insert_input {
	if len(aliases) == 0 {
		return nil
	}

	data := make([]*gql.Vulnerability_equivalent_insert_input, 0, len(aliases))
	for _, alias := range aliases {
		if _, known := lookupSourceFromId(alias); !known {
			log.Warn().
				Str("alias", alias).
				Msg("skipping alias with unknown source")
			continue
		}
		data = append(data, mapEquivalentElement(alias))
	}

	if len(data) == 0 {
		return nil
	}
	return &gql.Vulnerability_equivalent_arr_rel_insert_input{
		Data:        data,
		On_conflict: VulnerabilityEquivalentOnConflict,
	}
}

// mapEquivalentElement builds one equivalent row whose target is a stub
// vulnerability with the alias's source and source_id. NVD aliases also get
// their cve_id set. It panics (via determineSourceFromId) if the alias source
// is unknown.
func mapEquivalentElement(alias string) *gql.Vulnerability_equivalent_insert_input {
	source := determineSourceFromId(alias)
	var cveId *string
	if source == "nvd" {
		cveId = &alias
	}
	return &gql.Vulnerability_equivalent_insert_input{
		Vulnerability: nil,
		Equivalent_vulnerability: &gql.Vulnerability_obj_rel_insert_input{
			Data: &gql.Vulnerability_insert_input{
				Source:    util.Ptr(source),
				Source_id: util.Ptr(alias),
				Cve_id:    cveId,
			},
			On_conflict: VulnerabilityOnConflictAsEquivalentSubObject,
		},
	}
}

// determineSourceFromId returns the vulnerability source implied by id's
// prefix: "nvd" for CVE- ids and "github" for GHSA- ids. It panics for any
// other prefix; use lookupSourceFromId to test an id without panicking.
func determineSourceFromId(id string) string {
	source, ok := lookupSourceFromId(id)
	if !ok {
		panic("unable to determine source from id: " + id)
	}
	return source
}

// lookupSourceFromId is the non-panicking form of determineSourceFromId. It
// reports false for ids whose prefix is not recognised.
func lookupSourceFromId(id string) (string, bool) {
	switch {
	case strings.HasPrefix(id, "CVE-"):
		return "nvd", true
	case strings.HasPrefix(id, "GHSA-"):
		return "github", true
	default:
		return "", false
	}
}

// mapCredits converts the OSV credits list into the vulnerability credit
// relationship insert. It returns nil when there are no credits.
func mapCredits(credits []schema.OsvSchemaCreditsElem) *gql.Vulnerability_credit_arr_rel_insert_input {
	if len(credits) == 0 {
		return nil
	}

	rows := make([]*gql.Vulnerability_credit_insert_input, len(credits))
	for idx := range credits {
		rows[idx] = mapCreditElement(credits[idx])
	}

	return &gql.Vulnerability_credit_arr_rel_insert_input{
		Data:        rows,
		On_conflict: VulnerabilityCreditOnConflict,
	}
}

// mapCreditElement converts one OSV credit into a credit row. The contact list
// is stored as a Postgres array literal built by stringSlicePostgresqlArray.
func mapCreditElement(credit schema.OsvSchemaCreditsElem) *gql.Vulnerability_credit_insert_input {
	contactArray := stringSlicePostgresqlArray(credit.Contact)
	row := &gql.Vulnerability_credit_insert_input{}
	row.Contact = util.Ptr(contactArray)
	row.Name = util.Ptr(credit.Name)
	return row
}

// parseOptionalTime parses an optional RFC 3339 timestamp. A nil input yields
// nil; a non-nil input that fails to parse causes a panic with the parse error.
func parseOptionalTime(optionalTime *string) *time.Time {
	if optionalTime != nil {
		parsed, parseErr := time.Parse(time.RFC3339, *optionalTime)
		if parseErr != nil {
			panic(parseErr)
		}
		return &parsed
	}
	return nil
}

// mapAffected converts the OSV affected list into the vulnerability affected
// relationship insert. Entries are first merged per package by
// dedupeAffectedUsingPackageName. Entries that mapAffectedElement cannot map
// (it returns nil) are left out. It returns nil when affecteds is empty.
func mapAffected(affecteds []schema.OsvSchemaAffectedElem) *gql.Vulnerability_affected_arr_rel_insert_input {
	if len(affecteds) == 0 {
		return nil
	}

	packages := dedupeAffectedUsingPackageName(affecteds)
	inserts := make([]*gql.Vulnerability_affected_insert_input, 0, len(packages))
	for idx := range packages {
		if insert := mapAffectedElement(packages[idx]); insert != nil {
			inserts = append(inserts, insert)
		}
	}

	return &gql.Vulnerability_affected_arr_rel_insert_input{
		Data:        inserts,
		On_conflict: VulnerabilityAffectedOnConflict,
	}
}

// mapAffectedElement converts one (already de-duplicated) OSV affected entry
// into a vulnerability affected insert. The insert carries the package, the
// explicit versions, the legacy range events and the flattened ranges.
//
// It logs and returns nil when the entry cannot be mapped:
//   - the entry has no package;
//   - its database/ecosystem specific data fails to serialize;
//   - MapStringToPackageManager does not recognise its ecosystem.
//
// mapAffected skips nil results.
func mapAffectedElement(element schema.OsvSchemaAffectedElem) *gql.Vulnerability_affected_insert_input {
	// Check the package first, since every row below hangs off it. A missing
	// package is handled like the other unmappable cases (log and return nil)
	// instead of panicking and aborting the whole advisory.
	if element.Package == nil {
		log.Error().
			Msg("affected element has no package")
		return nil
	}

	databaseSpecific, err := serializeToJsonRawMessage(element.DatabaseSpecific)
	if err != nil {
		log.Error().
			Err(err).
			Msg("failed to serialize database specific information")
		return nil
	}

	ecosystemSpecific, err := serializeToJsonRawMessage(element.EcosystemSpecific)
	if err != nil {
		log.Error().
			Err(err).
			Msg("failed to serialize ecosystem specific information")
		return nil
	}

	// TODO: These are legacy raw ranges, they exist in the database as a flat record of the OSV data, but the clients should prefer to use the parsed ranges built by buildRangeInsert
	// TODO (cthompson) update package schema to accept repo
	// The repo returned by mapAffectedRanges is discarded until the package
	// schema can store it.
	rangeEvents, _ := mapAffectedRanges(element.Ranges)

	packageManager, err := MapStringToPackageManager(element.Package.Ecosystem)
	if err != nil {
		log.Error().
			Err(err).
			Str("ecosystem", element.Package.Ecosystem).
			Msg("failed to map ecosystem to package manager")
		return nil
	}

	affectedPackage := &gql.Package_obj_rel_insert_input{
		Data: &gql.Package_insert_input{
			// TODO (cthompson) have a default registry lookup table for package managers
			Custom_registry: util.Ptr(""),
			Name:            util.Ptr(element.Package.Name),
			Package_manager: util.Ptr(packageManager),
		},
		On_conflict: VulnerabilityPackageOnConflict,
	}

	return &gql.Vulnerability_affected_insert_input{
		Affected_range_events: rangeEvents,
		Affected_versions:     mapAffectedVersion(element.Versions),
		Database_specific:     util.Ptr(databaseSpecific),
		Ecosystem_specific:    util.Ptr(ecosystemSpecific),
		Package:               affectedPackage,
		Ranges:                buildRangeInsert(element.Ranges),
	}
}

// buildRangeInsert wraps the flattened ranges produced by parseRanges in the
// range relationship insert. It returns nil when there are no raw ranges.
func buildRangeInsert(rawRanges []schema.OsvSchemaAffectedElemRangesElem) (rangesInput *gql.Vulnerability_range_arr_rel_insert_input) {
	if len(rawRanges) > 0 {
		rangesInput = &gql.Vulnerability_range_arr_rel_insert_input{
			Data:        parseRanges(rawRanges),
			On_conflict: RangeOnConflict,
		}
	}
	return rangesInput
}

// parseRanges flattens OSV affected ranges into rows of a single range table,
// one row per introduced/fixed interval.
//
// This is the replacement for the legacy range-event mapping
// (mapAffectedRanges / mapAffectedRangeElement). For now both sets of
// functions and tables exist in parallel for legacy reasons, but hopefully this
// table will be good enough that we can deprecate the other tables.
//
// One OSV range may list several introduced/fixed pairs, for example
// introduced 0, fixed 1.2.0, introduced 2.0.0, fixed 2.1.3. Each pair becomes
// its own row instead of later pairs overwriting earlier ones:
//   - a "fixed" event closes the interval being built;
//   - an "introduced" that arrives while the current interval already has an
//     introduced value starts a new interval.
//
// Versions are normalized with semver.ParseTolerant. Events whose value does
// not parse (e.g. git commit hashes) are ignored. A range none of whose events
// parse still contributes a single empty row. Rows are de-duplicated by their
// introduced/fixed values.
func parseRanges(rawRanges []schema.OsvSchemaAffectedElemRangesElem) (ranges []*gql.Vulnerability_range_insert_input) {
	existingRangeSlugs := make(map[string]bool)

	// appendUnique adds r unless an interval with the same slug was already added.
	appendUnique := func(r *gql.Vulnerability_range_insert_input, slug string) {
		if existingRangeSlugs[slug] {
			return
		}
		existingRangeSlugs[slug] = true
		ranges = append(ranges, r)
	}

	for _, rawRange := range rawRanges {
		current := &gql.Vulnerability_range_insert_input{}
		slug := ""
		emitted := false

		// closeInterval records the interval being built and starts a fresh one.
		closeInterval := func() {
			appendUnique(current, slug)
			emitted = true
			current = &gql.Vulnerability_range_insert_input{}
			slug = ""
		}

		for _, event := range rawRange.Events {
			for eventPropName, eventPropValue := range event {
				versionString := getEventValue(eventPropValue)
				version, err := semver.ParseTolerant(versionString)
				if err != nil {
					continue
				}
				sanitizedVersionString := version.String()

				switch eventPropName {
				case "introduced":
					if current.Introduced != nil {
						// A second introduced before any fixed: the previous
						// interval is open-ended, so emit it as-is.
						closeInterval()
					}
					current.Introduced = &sanitizedVersionString
					slug += "introduced:" + sanitizedVersionString
				case "fixed":
					current.Fixed = &sanitizedVersionString
					slug += "fixed:" + sanitizedVersionString
					closeInterval()
				}
			}
		}

		// Emit a trailing open interval. If nothing was emitted for this range,
		// still emit the (possibly empty) row.
		if current.Introduced != nil || current.Fixed != nil || !emitted {
			appendUnique(current, slug)
		}
	}
	return ranges
}

// mapAffectedRanges converts OSV ranges into the legacy affected range event
// relationship insert, de-duplicating events across all ranges.
//
// The second return value is the repo of the last range that declares one, or
// "" if none does. The first return value is nil when ranges is empty.
func mapAffectedRanges(ranges []schema.OsvSchemaAffectedElemRangesElem) (*gql.Vulnerability_affected_range_event_arr_rel_insert_input, string) {
	if len(ranges) == 0 {
		return nil, ""
	}

	// Shared across ranges so an event already emitted for an earlier range is
	// not inserted again.
	seenEvents := make(map[string]bool)

	repo := ""
	events := make([]*gql.Vulnerability_affected_range_event_insert_input, 0, len(ranges))
	for i := range ranges {
		rangeElem := ranges[i]
		// TODO (cthompson) how should we handle multiple defined repos?
		if rangeElem.Repo != nil {
			repo = *rangeElem.Repo
		}
		events = append(events, mapAffectedRangeElement(seenEvents, rangeElem)...)
	}

	return &gql.Vulnerability_affected_range_event_arr_rel_insert_input{
		Data:        events,
		On_conflict: VulnerabilityAffectedRangeOnConflict,
	}, repo
}

// mapAffectedRangeElement converts the events of one OSV range into legacy
// range event rows. Each row carries the range type and the range's serialized
// database_specific data.
//
// eventLookup is shared with the caller. Events whose name+value key is
// already present are skipped, and new keys are added to it.
//
// It panics if database_specific cannot be serialized, if the range type is
// unknown, or if an event value is not a string.
func mapAffectedRangeElement(
	eventLookup map[string]bool,
	element schema.OsvSchemaAffectedElemRangesElem,
) []*gql.Vulnerability_affected_range_event_insert_input {
	databaseSpecific, serializeErr := serializeToJsonRawMessage(element.DatabaseSpecific)
	if serializeErr != nil {
		panic(serializeErr)
	}

	rangeType := mapRangeEventTypeToAffectedRangeType(element.Type)

	var rows []*gql.Vulnerability_affected_range_event_insert_input
	for _, event := range element.Events {
		for eventName, rawValue := range event {
			version := getEventValue(rawValue)
			// NOTE(review): the de-dup key omits the range type, so e.g. an
			// ECOSYSTEM and a SEMVER "introduced: 0" collapse into one row —
			// confirm this matches the table's unique constraint.
			dedupeKey := eventName + version
			if _, seen := eventLookup[dedupeKey]; seen {
				continue
			}
			eventLookup[dedupeKey] = true

			rows = append(rows, &gql.Vulnerability_affected_range_event_insert_input{
				Event:             util.Ptr(eventName),
				Type:              util.Ptr(rangeType),
				Version:           util.Ptr(version),
				Database_specific: util.Ptr(databaseSpecific),
			})
		}
	}
	return rows
}

// getEventValue returns an OSV range event value as a string. It panics when
// the value has any other dynamic type.
func getEventValue(value interface{}) string {
	if str, isString := value.(string); isString {
		return str
	}
	panic("unknown type for event value")
}

// mapAffectedVersion converts an explicit list of affected versions into the
// affected version relationship insert. Duplicates are removed, keeping the
// first occurrence. It returns nil when versions is empty.
func mapAffectedVersion(versions []string) *gql.Vulnerability_affected_version_arr_rel_insert_input {
	if len(versions) == 0 {
		return nil
	}

	seen := make(map[string]struct{}, len(versions))
	rows := make([]*gql.Vulnerability_affected_version_insert_input, 0, len(versions))
	for _, v := range versions {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		rows = append(rows, &gql.Vulnerability_affected_version_insert_input{
			Version: util.Ptr(v),
		})
	}

	return &gql.Vulnerability_affected_version_arr_rel_insert_input{
		Data:        rows,
		On_conflict: VulnerabilityAffectedVersion,
	}
}

// mapReferenceType lower-cases an OSV reference type and returns it as a
// types.ReferenceType. It panics if the result is not in types.ReferenceTypes.
func mapReferenceType(osvReferenceType schema.OsvSchemaReferencesElemType) types.ReferenceType {
	lowered := strings.ToLower(string(osvReferenceType))
	candidate := types.ReferenceType(lowered)
	if slices.Contains(types.ReferenceTypes, candidate) {
		return candidate
	}
	panic("unable to determine reference type for " + lowered)
}

// mapRangeEventTypeToAffectedRangeType lower-cases an OSV range type and
// returns it as a types.AffectedRangeType. It panics if the result is not in
// types.AffectedRangeTypes.
func mapRangeEventTypeToAffectedRangeType(eventType schema.OsvSchemaAffectedElemRangesElemType) types.AffectedRangeType {
	lowered := strings.ToLower(string(eventType))
	candidate := types.AffectedRangeType(lowered)
	if slices.Contains(types.AffectedRangeTypes, candidate) {
		return candidate
	}
	panic("unable to find range type for " + lowered)
}
